{
 "cells": [
  {
   "cell_type": "raw",
   "id": "a08b0eaa-f104-49a3-bcdf-a1b69ea31bc7",
   "metadata": {},
   "source": [
    "---\n",
    "title: calibration\n",
    "subtitle: Rigid transforms with camera calibration matrices\n",
    "---"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "03a5387b-a349-4b8e-a144-a56d30972ffa",
   "metadata": {},
   "source": [
    "An X-ray C-arm can be modeled as a pinhole camera with its own extrinsic and intrinsic matrices. \n",
    "This module provides utilities for parsing these matrices and working with rigid transforms."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8634b454-9289-4857-80f8-232d913de6a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| default_exp calibration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "359cb689-f4ab-4336-a5a9-b75be9492fb7",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "from nbdev.showdoc import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f67e02b-580c-4f37-9a90-fcca5713de08",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "import torch"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "594553a4-4ec9-46e5-96b4-4430edc84135",
   "metadata": {},
   "source": [
    "## Rigid transformations\n",
    "\n",
    "We represent rigid transforms as $4 \\times 4$ matrices (following the right-handed convention of `PyTorch3D`),\n",
    "\n",
    "\\begin{equation}\n",
    "\\begin{bmatrix}\n",
    "    \\mathbf R^T & \\mathbf 0 \\\\\n",
    "    \\mathbf t^T & 1\n",
    "\\end{bmatrix}\n",
    "\\in \\mathbf{SE}(3) \\,,\n",
    "\\end{equation}\n",
    "\n",
    "where $\\mathbf R \\in \\mathbf{SO}(3)$ is a rotation matrix and $\\mathbf t\\in \\mathbb R^3$ represents a translation.\n",
    "\n",
    "Note that since rotation matrices are orthogonal, we have a simple closed-form equation for the inverse:\n",
    "\\begin{equation}\n",
    "\\begin{bmatrix}\n",
    "    \\mathbf R^T & \\mathbf 0 \\\\\n",
    "    \\mathbf t^T & 1\n",
    "\\end{bmatrix}^{-1} =\n",
    "\\begin{bmatrix}\n",
    "    \\mathbf R & \\mathbf 0 \\\\\n",
    "    -\\mathbf R \\mathbf t & 1\n",
    "\\end{bmatrix} \\,.\n",
    "\\end{equation}\n",
    "\n",
    "For convenience, we add a wrapper of `pytorch3d.transforms.Transform3d` that can be construced from a (batched) rotation matrix and translation vector. This module also includes the closed-form inverse specific to rigid transforms."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "008781b6-de62-473f-b2d7-c2f7eb051b41",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "from typing import Optional\n",
    "\n",
    "from beartype import beartype\n",
    "from diffdrr.utils import Transform3d\n",
    "from diffdrr.utils import convert as convert_so3\n",
    "from diffdrr.utils import se3_exp_map, se3_log_map\n",
    "from jaxtyping import Float, jaxtyped"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "86492beb-3f1f-40ac-aff6-018ba1d04067",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "@beartype\n",
    "class RigidTransform(Transform3d):\n",
    "    \"\"\"Wrapper of pytorch3d.transforms.Transform3d with extra functionalities.\"\"\"\n",
    "\n",
    "    @jaxtyped(typechecker=beartype)\n",
    "    def __init__(\n",
    "        self,\n",
    "        R: Float[torch.Tensor, \"...\"],\n",
    "        t: Float[torch.Tensor, \"... 3\"],\n",
    "        parameterization: str = \"matrix\",\n",
    "        convention: Optional[str] = None,\n",
    "        device=None,\n",
    "        dtype=torch.float32,\n",
    "    ):\n",
    "        if device is None and (R.device == t.device):\n",
    "            device = R.device\n",
    "\n",
    "        R = convert_so3(R, parameterization, \"matrix\", convention)\n",
    "        if R.dim() == 2 and t.dim() == 1:\n",
    "            R = R.unsqueeze(0)\n",
    "            t = t.unsqueeze(0)\n",
    "        assert (batch_size := len(R)) == len(t), \"R and t need same batch size\"\n",
    "\n",
    "        matrix = torch.zeros(batch_size, 4, 4, device=device, dtype=dtype)\n",
    "        matrix[..., :3, :3] = R.transpose(-1, -2)\n",
    "        matrix[..., 3, :3] = t\n",
    "        matrix[..., 3, 3] = 1\n",
    "\n",
    "        super().__init__(matrix=matrix, device=device, dtype=dtype)\n",
    "\n",
    "    def get_rotation(self, parameterization=None, convention=None):\n",
    "        R = self.get_matrix()[..., :3, :3].transpose(-1, -2)\n",
    "        if parameterization is not None:\n",
    "            R = convert_so3(R, \"matrix\", parameterization, None, convention)\n",
    "        return R\n",
    "\n",
    "    def get_translation(self):\n",
    "        return self.get_matrix()[..., 3, :3]\n",
    "\n",
    "    def inverse(self):\n",
    "        \"\"\"Closed-form inverse for rigid transforms.\"\"\"\n",
    "        R = self.get_rotation().transpose(-1, -2)\n",
    "        t = self.get_translation()\n",
    "        t = -torch.einsum(\"bij,bj->bi\", R, t)\n",
    "        return RigidTransform(R, t, device=self.device, dtype=self.dtype)\n",
    "\n",
    "    def compose(self, other):\n",
    "        T = super().compose(other)\n",
    "        R = T.get_matrix()[..., :3, :3].transpose(-1, -2)\n",
    "        t = T.get_matrix()[..., 3, :3]\n",
    "        return RigidTransform(R, t, device=self.device, dtype=self.dtype)\n",
    "\n",
    "    def clone(self):\n",
    "        R = self.get_matrix()[..., :3, :3].transpose(-1, -2).clone()\n",
    "        t = self.get_matrix()[..., 3, :3].clone()\n",
    "        return RigidTransform(R, t, device=self.device, dtype=self.dtype)\n",
    "\n",
    "    def get_se3_log(self):\n",
    "        return se3_log_map(self.get_matrix())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff7984c3-a8f5-435f-b504-dce55787f517",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "def convert(\n",
    "    transform,\n",
    "    input_parameterization,\n",
    "    output_parameterization,\n",
    "    input_convention=None,\n",
    "    output_convention=None,\n",
    "):\n",
    "    \"\"\"Convert between representations of SE(3).\"\"\"\n",
    "\n",
    "    # Convert any input parameterization to a RigidTransform\n",
    "    if input_parameterization == \"se3_log_map\":\n",
    "        transform = torch.concat([transform[1], transform[0]], axis=-1)\n",
    "        matrix = se3_exp_map(transform).transpose(-1, -2)\n",
    "        transform = RigidTransform(\n",
    "            R=matrix[..., :3, :3],\n",
    "            t=matrix[..., :3, 3],\n",
    "            device=matrix.device,\n",
    "            dtype=matrix.dtype,\n",
    "        )\n",
    "    elif input_parameterization == \"se3_exp_map\":\n",
    "        pass\n",
    "    else:\n",
    "        transform = RigidTransform(\n",
    "            R=transform[0],\n",
    "            t=transform[1],\n",
    "            parameterization=input_parameterization,\n",
    "            convention=input_convention,\n",
    "        )\n",
    "\n",
    "    # Convert the RigidTransform to any output\n",
    "    if output_parameterization == \"se3_exp_map\":\n",
    "        return transform\n",
    "    elif output_parameterization == \"se3_log_map\":\n",
    "        se3_log = transform.get_se3_log()\n",
    "        log_t_vee = se3_log[..., :3]\n",
    "        log_R_vee = se3_log[..., 3:]\n",
    "        return log_R_vee, log_t_vee\n",
    "    else:\n",
    "        return (\n",
    "            transform.get_rotation(output_parameterization, output_convention),\n",
    "            transform.get_translation(),\n",
    "        )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7968b28c-ba34-49c7-8acc-bb080c0e4556",
   "metadata": {},
   "source": [
    "## Computing a perspective projection\n",
    "\n",
    "Given an `extrinsic` and `intrinsic` camera matrix, we can compute the perspective projection of a batch of points.\n",
    "This is used for computing where fiducials in world coordinates get mapped onto the image plane."
   ]
  },
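  {
   "cell_type": "markdown",
   "id": "perspective-projection-formula",
   "metadata": {},
   "source": [
    "As a sketch of the underlying formula in column-vector notation, write $\\mathbf R$ and $\\mathbf t$ for the rotation and translation of `extrinsic` and $\\mathbf K$ for `intrinsic` (the symbols $\\mathbf K$, $u$, $v$, and $w$ are notation introduced here for illustration). A world point $\\mathbf x$ is mapped to homogeneous image coordinates by\n",
    "\n",
    "\\begin{equation}\n",
    "\\begin{bmatrix} u \\\\ v \\\\ w \\end{bmatrix} = \\mathbf K (\\mathbf R \\mathbf x + \\mathbf t) \\,,\n",
    "\\end{equation}\n",
    "\n",
    "and the projected point is obtained by dividing by the last coordinate, $(u / w, \\, v / w)$.\n",
    "\n",
    "Assuming the common pinhole form of $\\mathbf K$ with focal lengths $f_x, f_y$ and principal point $(c_x, c_y)$, which this module does not enforce, a camera-frame point $(X, Y, Z)$ with $Z \\neq 0$ projects to\n",
    "\n",
    "\\begin{equation}\n",
    "\\left( f_x \\frac{X}{Z} + c_x \\,,\\; f_y \\frac{Y}{Z} + c_y \\right) \\,.\n",
    "\\end{equation}"
   ]
  },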
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8347585b-1a3b-4246-9fc4-d5f4dad5944b",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "@jaxtyped(typechecker=beartype)\n",
    "def perspective_projection(\n",
    "    extrinsic: RigidTransform,  # Extrinsic camera matrix (world to camera)\n",
    "    intrinsic: Float[torch.Tensor, \"3 3\"],  # Intrinsic camera matrix (camera to image)\n",
    "    x: Float[torch.Tensor, \"b n 3\"],  # World coordinates\n",
    ") -> Float[torch.Tensor, \"b n 2\"]:\n",
    "    x = extrinsic.transform_points(x)\n",
    "    x = torch.einsum(\"ij, bnj -> bni\", intrinsic, x)\n",
    "    z = x[..., -1].unsqueeze(-1).clone()\n",
    "    x = x / z\n",
    "    return x[..., :2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b18dbb07-f6cd-4de2-ad25-a784ad621471",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "import nbdev\n",
    "\n",
    "nbdev.nbdev_export()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a304b6f4-6558-4d63-99f8-dd9c82b9fddf",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
